 variance reduction


Multi-Head Attention as Ensemble Nadaraya-Watson Estimation: Variance Reduction, Decorrelation, and Optimal Head Diversity

arXiv.org Machine Learning

We develop a rigorous statistical theory of multi-head attention (MHA) as an ensemble of Nadaraya-Watson (NW) kernel regression estimators. Building on the algebraic identity between single-head softmax attention and the NW estimator, we prove that MHA is a structured ensemble of H NW estimators, each operating in a distinct learned projection subspace of the key space. We derive an explicit Bias-Variance-Covariance decomposition of the MHA mean squared error, showing that variance reduction depends not merely on the number of heads H but fundamentally on the decorrelation of head outputs. Decorrelation is governed by the principal angles between learned projection subspaces: orthogonal projections yield maximum variance reduction; aligned projections yield none. We introduce the Head Diversity Index (HDI), a computable spectral measure of inter-head decorrelation, and prove that MHA mean squared error is monotonically decreasing in HDI. This provides the first rigorous theoretical explanation for the empirically observed specialization of attention heads. Under a fixed total-dimension budget D = H * d_k, we solve the optimal head-dimension allocation problem, deriving the MSE-minimizing pair (H*, d_k*) from data distribution and regression smoothness. The solution yields a new architectural scaling law: the optimal per-head dimension grows logarithmically with training set size, while the optimal number of heads grows nearly linearly with the total budget D. Our framework unifies three strands of prior work: the NW theory of single-head attention, the general weighting theory for ensemble learning, and the decorrelation-variance-reduction isomorphism between biological and computational ensembles. Multi-head attention is the Transformer's instantiation of a universal principle: identical agents plus diversity-enforcing mechanisms yields emergent optimality.
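The single-head identity the abstract builds on can be checked numerically. The sketch below is an illustration only, not the paper's construction: it assumes scalar values, an exponential dot-product kernel, and hypothetical function names. Under those assumptions a Nadaraya-Watson estimate reproduces the softmax attention output.

```python
import math

def softmax_attention(query, keys, values, scale):
    """Single-head attention output for one query and scalar values."""
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    top = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total

def nadaraya_watson(query, keys, values, kernel):
    """Kernel-weighted average of the values: sum_i K(q, k_i) v_i / sum_j K(q, k_j)."""
    weights = [kernel(query, key) for key in keys]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total

def exp_dot_kernel(scale):
    """Kernel K(x, y) = exp(<x, y> / scale); an assumed choice for this sketch."""
    def kernel(x, y):
        return math.exp(sum(a * b for a, b in zip(x, y)) / scale)
    return kernel

# Toy data (hypothetical): three keys in R^2 with scalar values.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [1.0, 2.0, 4.0]
query = [0.5, -0.25]
scale = math.sqrt(2)

attn_out = softmax_attention(query, keys, values, scale)
nw_out = nadaraya_watson(query, keys, values, exp_dot_kernel(scale))
```

The two outputs agree up to floating-point error because the softmax normalizer is exactly the Nadaraya-Watson denominator for this kernel.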


Rennala MVR: Improved Time Complexity for Parallel Stochastic Optimization via Momentum-Based Variance Reduction

arXiv.org Machine Learning

Large-scale machine learning models are trained on clusters of machines that exhibit heterogeneous performance due to hardware variability, network delays, and system-level instabilities. In such environments, time complexity rather than iteration complexity becomes the relevant performance metric for optimization algorithms. Recent work by Tyurin and Richtárik [2023] established the first time complexity analysis for parallel first-order stochastic optimization, proposing Rennala SGD as a time-optimal method for smooth nonconvex optimization. However, Rennala SGD is fundamentally a modification of SGD, and variance reduction techniques are known to improve the iteration complexity of SGD. In this work, we investigate whether variance reduction can also improve time complexity in heterogeneous systems. We show that, under a mean-squared smoothness assumption, variance reduction can improve time complexity in relevant parameter regimes. To this end, we propose Rennala MVR, a variance-reduced extension of Rennala SGD based on momentum-based variance reduction, and analyze its oracle and time complexity. We establish lower bounds for time complexity under these assumptions.
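For orientation, a generic momentum-based variance reduction update (in the style of STORM) can be sketched on a one-dimensional toy problem. This is not Rennala MVR itself: the abstract does not specify the algorithm, so the loss, step size `lr`, momentum parameter `a`, and function names below are all assumptions made for illustration.

```python
import random

def stochastic_grad(x, xi):
    """Gradient in x of the toy loss 0.5 * (x - xi)**2 for one sample xi."""
    return x - xi

def mvr_minimize(x0, steps, lr, a, rng):
    """Momentum-based variance reduction on the toy loss (hypothetical sketch)."""
    x_prev = x0
    xi = rng.gauss(0.0, 1.0)
    d = stochastic_grad(x_prev, xi)  # initial direction: a plain stochastic gradient
    x = x_prev - lr * d
    for _ in range(steps):
        xi = rng.gauss(0.0, 1.0)
        # The same sample xi is evaluated at the current and the previous iterate,
        # so the correction term tracks how the gradient changed between them.
        d = stochastic_grad(x, xi) + (1.0 - a) * (d - stochastic_grad(x_prev, xi))
        x_prev, x = x, x - lr * d
    return x

# With xi ~ N(0, 1), the expected loss is minimized at x = 0.
x_final = mvr_minimize(x0=5.0, steps=500, lr=0.1, a=0.1, rng=random.Random(0))
```

Setting `a = 1.0` removes the correction term and recovers plain SGD, which makes the role of the momentum parameter easy to see.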



Knowledge Distillation Performs Partial Variance Reduction

Neural Information Processing Systems

Knowledge distillation is a popular approach for enhancing the performance of "student" models, with lower representational capacity, by taking advantage of more powerful "teacher" models. Despite its apparent simplicity and widespread use, the underlying mechanics behind knowledge distillation (KD) are still not fully understood. In this work, we shed new light on the inner workings of this method, by examining it from an optimization perspective. We show that, in the context of linear and deep linear models, KD can be interpreted as a novel type of stochastic variance reduction mechanism. We provide a detailed convergence analysis of the resulting dynamics, which hold under standard assumptions for both strongly-convex and non-convex losses, showing that KD acts as a form of partial variance reduction, which can reduce the stochastic gradient noise, but may not eliminate it completely, depending on the properties of the "teacher" model. Our analysis puts further emphasis on the need for careful parametrization of KD, in particular w.r.t. the weighting of the distillation loss, and is validated empirically on both linear models and deep neural networks.


Sharp Analysis of Stochastic Optimization under Global Kurdyka-Łojasiewicz Inequality

Neural Information Processing Systems

We study the complexity of finding the global solution to stochastic nonconvex optimization when the objective function satisfies global Kurdyka-Łojasiewicz (KŁ) inequality and the queries from stochastic gradient oracles satisfy mild expected smoothness assumption. We first introduce a general framework to analyze Stochastic Gradient Descent (SGD) and its associated nonlinear dynamics under the setting. As a byproduct of our analysis, we obtain a sample complexity of O(ϵ^{-(4-α)/α}) for SGD when the objective satisfies the so called α-PŁ condition, where α is the degree of gradient domination. Furthermore, we show that a modified SGD with variance reduction and restarting (PAGER) achieves an improved sample complexity of O(ϵ^{-2/α}) when the objective satisfies the average smoothness assumption. This leads to the first optimal algorithm for the important case of α = 1 which appears in applications such as policy optimization in reinforcement learning.


Machine Learning for Variance Reduction in Online Experiments

Neural Information Processing Systems

We consider the problem of variance reduction in randomized controlled trials, through the use of covariates correlated with the outcome but independent of the treatment. We propose a machine learning regression-adjusted treatment effect estimator, which we call MLRATE. MLRATE uses machine learning predictors of the outcome to reduce estimator variance. It employs cross-fitting to avoid overfitting biases, and we prove consistency and asymptotic normality under general conditions. MLRATE is robust to poor predictions from the machine learning step: if the predictions are uncorrelated with the outcomes, the estimator performs asymptotically no worse than the standard difference-in-means estimator, while if predictions are highly correlated with outcomes, the efficiency gains are large. In A/A tests, for a set of 48 outcome metrics commonly monitored in Facebook experiments, the estimator has over 70% lower variance than the simple difference-in-means estimator, and about 19% lower variance than the common univariate procedure which adjusts only for pre-experiment values of the outcome.
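The idea of adjusting with cross-fitted outcome predictions can be sketched on simulated data. The code below is a simplified illustration, not the MLRATE estimator as defined in the paper: it assumes a one-dimensional covariate, uses a least-squares line as a stand-in for the machine learning predictor, and applies a single pooled adjustment coefficient. All names and the data-generating process are hypothetical.

```python
import random
import statistics

def fit_line(xs, ys):
    """Least-squares line; a stand-in for any ML outcome predictor."""
    mx, my = statistics.fmean(xs), statistics.fmean(ys)
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return lambda x: my + slope * (x - mx)

def cross_fit_predictions(xs, ys, folds=2):
    """Predict each unit's outcome with a model trained on the other folds."""
    n = len(xs)
    preds = [0.0] * n
    for f in range(folds):
        train = [i for i in range(n) if i % folds != f]
        model = fit_line([xs[i] for i in train], [ys[i] for i in train])
        for i in range(f, n, folds):
            preds[i] = model(xs[i])
    return preds

def diff_in_means(ys, ts):
    treated = [y for y, t in zip(ys, ts) if t == 1]
    control = [y for y, t in zip(ys, ts) if t == 0]
    return statistics.fmean(treated) - statistics.fmean(control)

def adjusted_estimate(xs, ys, ts):
    """Difference in means after subtracting the part of y explained by predictions."""
    g = cross_fit_predictions(xs, ys)
    g_bar, y_bar = statistics.fmean(g), statistics.fmean(ys)
    cov = sum((y - y_bar) * (gi - g_bar) for y, gi in zip(ys, g))
    var = sum((gi - g_bar) ** 2 for gi in g)
    theta = cov / var
    residual = [y - theta * (gi - g_bar) for y, gi in zip(ys, g)]
    return diff_in_means(residual, ts)

def simulate(rng, n=200, effect=1.0):
    """Hypothetical experiment: covariate x drives most of the outcome variance."""
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ts = [rng.randint(0, 1) for _ in range(n)]
    ys = [2.0 * x + effect * t + rng.gauss(0.0, 0.5) for x, t in zip(xs, ts)]
    return xs, ys, ts

rng = random.Random(0)
plain, adjusted = [], []
for _ in range(200):
    xs, ys, ts = simulate(rng)
    plain.append(diff_in_means(ys, ts))
    adjusted.append(adjusted_estimate(xs, ys, ts))

plain_var = statistics.pvariance(plain)        # spread of the unadjusted estimator
adjusted_var = statistics.pvariance(adjusted)  # spread after regression adjustment
```

Because the treatment is assigned independently of the covariate, the adjustment leaves the target effect unchanged while removing outcome variance that the predictions explain.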



Improving Machine Learning Performance with Synthetic Augmentation

arXiv.org Machine Learning

Synthetic augmentation is increasingly used to mitigate data scarcity in financial machine learning, yet its statistical role remains poorly understood. We formalize synthetic augmentation as a modification of the effective training distribution and show that it induces a structural bias-variance trade-off: while additional samples may reduce estimation error, they may also shift the population objective whenever the synthetic distribution deviates from regions relevant under evaluation. To isolate informational gains from mechanical sample-size effects, we introduce a size-matched null augmentation and a finite-sample, non-parametric block permutation test that remains valid under weak temporal dependence. We evaluate this framework in both controlled Markov-switching environments and real financial datasets, including high-frequency option trade data and a daily equity panel. Across generators spanning bootstrap, copula-based models, variational autoencoders, diffusion models, and TimeGAN, we vary augmentation ratio, model capacity, task type, regime rarity, and signal-to-noise. We show that synthetic augmentation is beneficial only in variance-dominant regimes, such as persistent volatility forecasting, while it deteriorates performance in bias-dominant settings, including near-efficient directional prediction. Rare-regime targeting can improve domain-specific metrics but may conflict with unconditional permutation inference. Our results provide a structural perspective on when synthetic data improves financial learning performance and when it induces persistent distributional distortion.



Conditional neural control variates for variance reduction in Bayesian inverse problems

arXiv.org Machine Learning

Bayesian inference for inverse problems involves computing expectations under posterior distributions -- e.g., posterior means, variances, or predictive quantities -- typically via Monte Carlo (MC) estimation. When the quantity of interest varies significantly under the posterior, accurate estimates demand many samples -- a cost often prohibitive for partial differential equation-constrained problems. To address this challenge, we introduce conditional neural control variates, a modular method that learns amortized control variates from joint model-data samples to reduce the variance of MC estimators. To scale to high-dimensional problems, we leverage Stein's identity to design an architecture based on an ensemble of hierarchical coupling layers with tractable Jacobian trace computation. Training requires: (i) samples from the joint distribution of unknown parameters and observed data; and (ii) the posterior score function, which can be computed from physics-based likelihood evaluations, neural operator surrogates, or learned generative models such as conditional normalizing flows. Once trained, the control variates generalize across observations without retraining. We validate our approach on stylized and partial differential equation-constrained Darcy flow inverse problems, demonstrating substantial variance reduction, even when the analytical score is replaced by a learned surrogate.
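As background for the control-variate idea, the classical scalar version can be sketched in a few lines. This is not the conditional neural method described in the abstract: it assumes a toy integrand f(u) = exp(u) with u uniform on [0, 1], a hand-picked control variate g(u) = u whose mean 0.5 is known, and a coefficient estimated from the same samples. All names are hypothetical.

```python
import math
import random
import statistics

def control_variate_estimate(f_vals, g_vals, g_mean):
    """Monte Carlo mean of f, corrected by a control variate g with known mean."""
    f_bar = statistics.fmean(f_vals)
    g_bar = statistics.fmean(g_vals)
    cov = sum((f - f_bar) * (g - g_bar) for f, g in zip(f_vals, g_vals))
    var = sum((g - g_bar) ** 2 for g in g_vals)
    c = cov / var  # sample estimate of the variance-minimizing coefficient
    adjusted = [f - c * (g - g_mean) for f, g in zip(f_vals, g_vals)]
    return statistics.fmean(adjusted), adjusted

rng = random.Random(0)
u = [rng.random() for _ in range(5000)]
f_vals = [math.exp(x) for x in u]  # target: E[exp(U)] = e - 1

cv_mean, adjusted = control_variate_estimate(f_vals, u, 0.5)
plain_var = statistics.pvariance(f_vals)   # per-sample variance without correction
cv_var = statistics.pvariance(adjusted)    # per-sample variance with correction
```

The correction term has zero expectation, so the target is unchanged; the gain depends on how strongly g correlates with f, which is the quantity a learned control variate tries to maximize.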